According to the World Health Organization, stroke is the second leading cause of death globally, and according to the Heart Disease and Stroke Statistics 2019 report, it is the fifth leading cause of death in the United States. The most recent HDSS report also shows that someone in the United States has a stroke every 40 seconds and someone dies from one every 3.5 minutes. Given the prevalence and seriousness of the condition, being able to predict a person’s likelihood of suffering a stroke in advance could help in assessing risk and planning treatment accordingly.
Using the “healthcare-dataset-stroke-data” dataset from Kaggle, we want to see, first, which variables are associated with a patient having a stroke and, second, whether we can find a model that predicts whether a patient will have a stroke. Previous research shows that age, heart disease, average glucose level and hypertension are the most important factors for stroke prediction; the dataset we are using contains all of these as well as several other variables that may reveal interesting patterns.
The data for this project come from Kaggle and contain 5,110 observations with twelve attributes: id, gender, age, whether hypertension is present, whether heart disease is present, whether the patient has ever been married, type of work, type of residence (rural or urban), average glucose level, BMI, smoking status and whether the patient had a stroke. Each row of data corresponds to one patient:
id: unique identifier
gender: “Male”, “Female” or “Other”
age: age of the patient
hypertension: 0 if the patient doesn’t have hypertension, 1 if the patient has hypertension
heart_disease: 0 if the patient doesn’t have any heart diseases, 1 if the patient has a heart disease
ever_married: “No” or “Yes”
work_type: “children”, “Govt_job”, “Never_worked”, “Private” or “Self-employed”
Residence_type: “Rural” or “Urban”
avg_glucose_level: average glucose level in blood
bmi: body mass index
smoking_status: “formerly smoked”, “never smoked”, “smokes” or “Unknown” (“Unknown” means the information is unavailable for this patient)
stroke: 1 if the patient had a stroke or 0 if not
# data
library(dplyr)
library(ggplot2)
library(Amelia)
library(corrplot)
library(corrgram)
# model
library(caTools)
library(e1071)
library(caret)
library(ROSE)
library(Metrics)
library(class)
library(tidymodels)
library(glmnet)
# read the data
data <- read.csv("healthcare-dataset-stroke-data.csv")
# view the data
head(data)
## id gender age hypertension heart_disease ever_married work_type
## 1 9046 Male 67 0 1 Yes Private
## 2 51676 Female 61 0 0 Yes Self-employed
## 3 31112 Male 80 0 1 Yes Private
## 4 60182 Female 49 0 0 Yes Private
## 5 1665 Female 79 1 0 Yes Self-employed
## 6 56669 Male 81 0 0 Yes Private
## Residence_type avg_glucose_level bmi smoking_status stroke
## 1 Urban 228.69 36.6 formerly smoked 1
## 2 Rural 202.21 N/A never smoked 1
## 3 Rural 105.92 32.5 never smoked 1
## 4 Urban 171.23 34.4 smokes 1
## 5 Rural 174.12 24 never smoked 1
## 6 Urban 186.21 29 formerly smoked 1
str(data)
## 'data.frame': 5110 obs. of 12 variables:
## $ id : int 9046 51676 31112 60182 1665 56669 53882 10434 27419 60491 ...
## $ gender : chr "Male" "Female" "Male" "Female" ...
## $ age : num 67 61 80 49 79 81 74 69 59 78 ...
## $ hypertension : int 0 0 0 0 1 0 1 0 0 0 ...
## $ heart_disease : int 1 0 1 0 0 0 1 0 0 0 ...
## $ ever_married : chr "Yes" "Yes" "Yes" "Yes" ...
## $ work_type : chr "Private" "Self-employed" "Private" "Private" ...
## $ Residence_type : chr "Urban" "Rural" "Rural" "Urban" ...
## $ avg_glucose_level: num 229 202 106 171 174 ...
## $ bmi : chr "36.6" "N/A" "32.5" "34.4" ...
## $ smoking_status : chr "formerly smoked" "never smoked" "never smoked" "smokes" ...
## $ stroke : int 1 1 1 1 1 1 1 1 1 1 ...
summary(data)
## id gender age hypertension
## Min. : 67 Length:5110 Min. : 0.08 Min. :0.00000
## 1st Qu.:17741 Class :character 1st Qu.:25.00 1st Qu.:0.00000
## Median :36932 Mode :character Median :45.00 Median :0.00000
## Mean :36518 Mean :43.23 Mean :0.09746
## 3rd Qu.:54682 3rd Qu.:61.00 3rd Qu.:0.00000
## Max. :72940 Max. :82.00 Max. :1.00000
## heart_disease ever_married work_type Residence_type
## Min. :0.00000 Length:5110 Length:5110 Length:5110
## 1st Qu.:0.00000 Class :character Class :character Class :character
## Median :0.00000 Mode :character Mode :character Mode :character
## Mean :0.05401
## 3rd Qu.:0.00000
## Max. :1.00000
## avg_glucose_level bmi smoking_status stroke
## Min. : 55.12 Length:5110 Length:5110 Min. :0.00000
## 1st Qu.: 77.25 Class :character Class :character 1st Qu.:0.00000
## Median : 91.89 Mode :character Mode :character Median :0.00000
## Mean :106.15 Mean :0.04873
## 3rd Qu.:114.09 3rd Qu.:0.00000
## Max. :271.74 Max. :1.00000
# check NA values
any(is.na(data)) # FALSE here, because missing bmi values are stored as the string "N/A"
## [1] FALSE
# bmi stores missing values as the string "N/A"; convert them to NA
data[data == 'N/A'] <- NA
missmap(data, col=c("yellow", "black"), legend=FALSE) # the missing values are all in bmi
table(is.na(data)) # there are 201 missing values
##
## FALSE TRUE
## 61119 201
# convert bmi data type to numeric
data$bmi <- as.numeric(data$bmi)
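The two-step conversion above (replace the string "N/A" with NA, then coerce bmi to numeric) can also be done while reading the file. The following is a minimal sketch, assuming the same "N/A" marker; the inline CSV text and the name `toy` are made up for illustration and stand in for the real file.

```r
# Toy CSV standing in for healthcare-dataset-stroke-data.csv (values are illustrative).
csv_text <- "id,bmi,stroke
1,36.6,1
2,N/A,1
3,32.5,0"

# na.strings tells read.csv to treat the literal string "N/A" as a missing value,
# so bmi is parsed as numeric directly instead of character.
toy <- read.csv(text = csv_text, na.strings = "N/A")

str(toy$bmi)          # numeric column containing one NA
sum(is.na(toy$bmi))   # number of missing bmi values
```

With this approach `any(is.na(...))` reports the missing values immediately after loading, without a separate recoding step.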
# plot bmi
hist(data$bmi)
boxplot(data$bmi)
# drop NA values
data <- na.omit(data)
any(is.na(data)) # check again, there is no NA value now
## [1] FALSE
# drop the id column
data <- data[-1]
head(data)
## gender age hypertension heart_disease ever_married work_type
## 1 Male 67 0 1 Yes Private
## 3 Male 80 0 1 Yes Private
## 4 Female 49 0 0 Yes Private
## 5 Female 79 1 0 Yes Self-employed
## 6 Male 81 0 0 Yes Private
## 7 Male 74 1 1 Yes Private
## Residence_type avg_glucose_level bmi smoking_status stroke
## 1 Urban 228.69 36.6 formerly smoked 1
## 3 Rural 105.92 32.5 never smoked 1
## 4 Urban 171.23 34.4 smokes 1
## 5 Rural 174.12 24.0 never smoked 1
## 6 Urban 186.21 29.0 formerly smoked 1
## 7 Rural 70.09 27.4 never smoked 1
# data transformation
str(data)
## 'data.frame': 4909 obs. of 11 variables:
## $ gender : chr "Male" "Male" "Female" "Female" ...
## $ age : num 67 80 49 79 81 74 69 78 81 61 ...
## $ hypertension : int 0 0 0 1 0 1 0 0 1 0 ...
## $ heart_disease : int 1 1 0 0 0 1 0 0 0 1 ...
## $ ever_married : chr "Yes" "Yes" "Yes" "Yes" ...
## $ work_type : chr "Private" "Private" "Private" "Self-employed" ...
## $ Residence_type : chr "Urban" "Rural" "Urban" "Rural" ...
## $ avg_glucose_level: num 229 106 171 174 186 ...
## $ bmi : num 36.6 32.5 34.4 24 29 27.4 22.8 24.2 29.7 36.8 ...
## $ smoking_status : chr "formerly smoked" "never smoked" "smokes" "never smoked" ...
## $ stroke : int 1 1 1 1 1 1 1 1 1 1 ...
# convert character data type to factor
data <- data %>% mutate(across(where(is.character),factor))
# convert hypertension, heart_disease, stroke data type from integer to factor
data$hypertension <- as.factor(data$hypertension)
data$heart_disease <- as.factor(data$heart_disease)
data$stroke <- as.factor(data$stroke)
# binning numeric variables
# age
ggplot(data, aes(age, y = after_stat(density))) +
  geom_histogram(binwidth = 1,
                 color = "black",
                 fill = "#02bcfa",
                 alpha = 0.5) +
  geom_density() + labs(title = "Age Distribution")
boxplot(data$age)
summary(data$age)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.08 25.00 44.00 42.87 60.00 82.00
# binning age with quantile: 25, 44, 60, 82
data$age <- cut(data$age,
breaks = c(0, 25, 44, 60, 82),
labels=c('young', 'grown', 'mature', 'old'))
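The break points above are copied by hand from the `summary()` output. As a hedged alternative, they could be computed from the data with `quantile()`. The sketch below uses a made-up `ages` vector rather than the real column, and reuses the same four labels.

```r
# Illustrative ages; in the analysis this would be the numeric age column before binning.
ages <- c(0.08, 3, 17, 25, 31, 44, 52, 60, 71, 82)

# Quartile break points computed from the data instead of typed by hand.
breaks <- quantile(ages, probs = c(0, 0.25, 0.5, 0.75, 1))

# include.lowest = TRUE keeps the minimum value inside the first bin;
# without it, an observation equal to the lowest break would become NA.
age_group <- cut(ages,
                 breaks = breaks,
                 labels = c("young", "grown", "mature", "old"),
                 include.lowest = TRUE)

table(age_group)
```

Computing the breaks this way keeps the bins in sync with the data if rows are added or removed, at the cost of bin edges that are less round than 25, 44 and 60.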
# avg glucose level
ggplot(data, aes(avg_glucose_level, y = after_stat(density))) +
geom_histogram(color="black",
fill="#02bcfa",
alpha=0.5) +
geom_density() +
labs(title="Average Glucose Level Distribution")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
boxplot(data$avg_glucose_level)
summary(data$avg_glucose_level)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 55.12 77.07 91.68 105.31 113.57 271.74
# binning avg glucose level based on the information on website:
# https://my.clevelandclinic.org/health/diagnostics/12363-blood-glucose-test#:~:text=What%20is%20a%20normal%20glucose,can%20be%20%E2%80%9Cnormal%E2%80%9D%20too.
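group_glucose <- function(level){
  # start from an empty character vector so the labels are not coerced from numbers
  res <- character(length(level))
  for (i in seq_along(level)){
    if (level[i] <= 70){
      res[i] <- "low"
    } else if (level[i] <= 99) {
      res[i] <- "normal"
    } else if (level[i] <= 125) {
      # only upper bounds are tested, so readings between 99 and 100 are not skipped
      res[i] <- "prediabetes"
    } else {
      res[i] <- "diabetes"
    }
  }
  return(res)
}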
# apply group_glucose function
data$avg_glucose_level <- group_glucose(data$avg_glucose_level)
# convert avg_glucose_level data type to factor
data$avg_glucose_level <- as.factor(data$avg_glucose_level)
# levels of data$avg_glucose_level are in the wrong order
levels(data$avg_glucose_level)
## [1] "diabetes" "low" "normal" "prediabetes"
# reorder the levels of data$avg_glucose_level
data$avg_glucose_level <- factor(data$avg_glucose_level, levels = c("low", "normal", "prediabetes", "diabetes"))
# check again
levels(data$avg_glucose_level)
## [1] "low" "normal" "prediabetes" "diabetes"
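The grouping and level reordering above can also be expressed in one vectorized step with `cut()`. The following is a sketch, not the code used for the results in this report: `group_glucose_cut` is a hypothetical helper name, and the cut points 70, 99 and 125 are taken from the thresholds used above. Because adjacent intervals share their break points, no reading can fall between two groups.

```r
# Vectorized alternative to a row-by-row loop for grouping glucose readings.
# group_glucose_cut is a hypothetical helper written for this sketch.
group_glucose_cut <- function(level) {
  cut(level,
      breaks = c(-Inf, 70, 99, 125, Inf),
      labels = c("low", "normal", "prediabetes", "diabetes"),
      right = TRUE)   # intervals are closed on the right: (70, 99], (99, 125], ...
}

# Illustrative readings, including values close to the boundaries.
readings <- c(65, 70, 85, 99.5, 100, 125, 130)
grouped <- group_glucose_cut(readings)

data.frame(readings, grouped)
levels(grouped)   # already ordered from "low" to "diabetes"
```

`cut()` returns a factor whose levels follow the order of `labels`, so the separate `as.factor()` and reordering steps are not needed with this variant.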
# bmi
ggplot(data, aes(bmi, y = after_stat(density))) +
geom_histogram(color="black",
fill="#02bcfa",
alpha=0.5) +
geom_density() +
labs(title="BMI Distribution")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
boxplot(data$bmi)
summary(data$bmi)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 10.30 23.50 28.10 28.89 33.10 97.60
# binning BMI based on the information on the CDC website:
# https://www.cdc.gov/healthyweight/assessing/index.html#:~:text=If%20your%20BMI%20is%20less,falls%20within%20the%20obese%20range.
group_bmi <- function(bmi){
  # start from an empty character vector so the labels are not coerced from numbers
  res <- character(length(bmi))
  for (i in seq_along(bmi)){
    if (bmi[i] < 18.5){
      res[i] <- "underweight"
    } else if (bmi[i] < 25) {
      res[i] <- "normal"
    } else if (bmi[i] < 30) {
      res[i] <- "overweight"
    } else {
      res[i] <- "obese"
    }
  }
  return(res)
}
# apply group_bmi function
data$bmi <- group_bmi(data$bmi)
# convert bmi data type to factor
data$bmi <- as.factor(data$bmi)
# levels of data$bmi are in the wrong order
levels(data$bmi)
## [1] "normal" "obese" "overweight" "underweight"
# reorder the levels of bmi
data$bmi <- factor(data$bmi, levels = c("underweight", "normal", "overweight", "obese"))
# check the structure
str(data)
## 'data.frame': 4909 obs. of 11 variables:
## $ gender : Factor w/ 3 levels "Female","Male",..: 2 2 1 1 2 2 1 1 1 1 ...
## $ age : Factor w/ 4 levels "young","grown",..: 4 4 3 4 4 4 4 4 4 4 ...
## $ hypertension : Factor w/ 2 levels "0","1": 1 1 1 2 1 2 1 1 2 1 ...
## $ heart_disease : Factor w/ 2 levels "0","1": 2 2 1 1 1 2 1 1 1 2 ...
## $ ever_married : Factor w/ 2 levels "No","Yes": 2 2 2 2 2 2 1 2 2 2 ...
## $ work_type : Factor w/ 5 levels "children","Govt_job",..: 4 4 4 5 4 4 4 4 4 2 ...
## $ Residence_type : Factor w/ 2 levels "Rural","Urban": 2 1 2 1 2 1 2 2 1 1 ...
## $ avg_glucose_level: Factor w/ 4 levels "low","normal",..: 4 3 4 4 4 2 2 1 2 3 ...
## $ bmi : Factor w/ 4 levels "underweight",..: 4 4 4 2 3 3 2 2 3 4 ...
## $ smoking_status : Factor w/ 4 levels "formerly smoked",..: 1 2 3 2 1 2 2 4 2 3 ...
## $ stroke : Factor w/ 2 levels "0","1": 2 2 2 2 2 2 2 2 2 2 ...
# check the correlation
# convert factor variables to numeric variables
data_num <- data %>% mutate(across(where(is.factor),as.numeric))
str(data_num)
## 'data.frame': 4909 obs. of 11 variables:
## $ gender : num 2 2 1 1 2 2 1 1 1 1 ...
## $ age : num 4 4 3 4 4 4 4 4 4 4 ...
## $ hypertension : num 1 1 1 2 1 2 1 1 2 1 ...
## $ heart_disease : num 2 2 1 1 1 2 1 1 1 2 ...
## $ ever_married : num 2 2 2 2 2 2 1 2 2 2 ...
## $ work_type : num 4 4 4 5 4 4 4 4 4 2 ...
## $ Residence_type : num 2 1 2 1 2 1 2 2 1 1 ...
## $ avg_glucose_level: num 4 3 4 4 4 2 2 1 2 3 ...
## $ bmi : num 4 4 4 2 3 3 2 2 3 4 ...
## $ smoking_status : num 1 2 3 2 1 2 2 4 2 3 ...
## $ stroke : num 2 2 2 2 2 2 2 2 2 2 ...
# correlation and corrplot
(cor <- cor(data_num))
## gender age hypertension heart_disease
## gender 1.000000000 -0.01692503 0.021578286 0.082711652
## age -0.016925028 1.00000000 0.270583325 0.251216361
## hypertension 0.021578286 0.27058332 1.000000000 0.115990991
## heart_disease 0.082711652 0.25121636 0.115990991 1.000000000
## ever_married -0.037236693 0.66476920 0.162406260 0.111245121
## work_type -0.072538268 0.45266958 0.124654706 0.092144819
## Residence_type -0.005013763 0.01457606 -0.001074146 -0.002361744
## avg_glucose_level 0.040858382 0.14965716 0.124198217 0.106454582
## bmi 0.020621827 0.38468598 0.159444545 0.085086762
## smoking_status 0.038252248 -0.34128144 -0.132831660 -0.071396924
## stroke 0.006757363 0.21921811 0.142514606 0.137937788
## ever_married work_type Residence_type avg_glucose_level
## gender -0.037236693 -0.0725382684 -0.0050137626 0.040858382
## age 0.664769200 0.4526695779 0.0145760598 0.149657159
## hypertension 0.162406260 0.1246547061 -0.0010741462 0.124198217
## heart_disease 0.111245121 0.0921448190 -0.0023617439 0.106454582
## ever_married 1.000000000 0.4259143556 0.0049891711 0.096121554
## work_type 0.425914356 1.0000000000 -0.0008827106 0.054018824
## Residence_type 0.004989171 -0.0008827106 1.0000000000 -0.007651788
## avg_glucose_level 0.096121554 0.0540188244 -0.0076517877 1.000000000
## bmi 0.404389675 0.4020328198 -0.0074476403 0.112484272
## smoking_status -0.310702330 -0.3444032458 0.0027191093 -0.065479200
## stroke 0.105089144 0.0797450467 0.0060314265 0.094897697
## bmi smoking_status stroke
## gender 0.02062183 0.038252248 0.006757363
## age 0.38468598 -0.341281444 0.219218106
## hypertension 0.15944454 -0.132831660 0.142514606
## heart_disease 0.08508676 -0.071396924 0.137937788
## ever_married 0.40438967 -0.310702330 0.105089144
## work_type 0.40203282 -0.344403246 0.079745047
## Residence_type -0.00744764 0.002719109 0.006031426
## avg_glucose_level 0.11248427 -0.065479200 0.094897697
## bmi 1.00000000 -0.282826306 0.064070432
## smoking_status -0.28282631 1.000000000 -0.075919784
## stroke 0.06407043 -0.075919784 1.000000000
# corrplot
corrplot(cor, method = "color")
# corrgram
corrgram(data_num, order = TRUE,
lower.panel = panel.shade,
upper.panel = panel.pie,
text.panel = panel.txt)
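One caveat for the correlation matrix above: several columns (for example work_type and smoking_status) are nominal, so their integer codes follow alphabetical level order and the Pearson coefficient depends on that arbitrary coding. A measure designed for categorical pairs, such as Cramér's V, avoids this. The sketch below uses base R only; `cramers_v` is a hypothetical helper and the toy vectors are made up for illustration.

```r
# Cramer's V: association between two categorical variables, from 0 (none) to 1 (perfect).
# cramers_v is a hypothetical helper written for this sketch, not part of any package used above.
cramers_v <- function(x, y) {
  tab  <- table(x, y)
  # correct = FALSE turns off the continuity correction applied to 2x2 tables.
  chi2 <- suppressWarnings(chisq.test(tab, correct = FALSE)$statistic)
  n    <- sum(tab)
  k    <- min(nrow(tab), ncol(tab)) - 1
  as.numeric(sqrt(chi2 / (n * k)))
}

# Toy data: smoking is perfectly tied to the outcome, residence is unrelated to it.
outcome   <- factor(c(1, 1, 1, 1, 0, 0, 0, 0))
smoking   <- factor(c("smokes", "smokes", "smokes", "smokes",
                      "never", "never", "never", "never"))
residence <- factor(c("Urban", "Rural", "Urban", "Rural",
                      "Urban", "Rural", "Urban", "Rural"))

v_smoking   <- cramers_v(smoking, outcome)    # perfect association
v_residence <- cramers_v(residence, outcome)  # no association
c(smoking = v_smoking, residence = v_residence)
```

Unlike the sign of a Pearson coefficient on integer codes, this value does not change if the factor levels are relabeled or reordered.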
# Age Group Distribution
ggplot(data, aes(age)) +
geom_bar(aes(fill = age)) +
scale_fill_brewer(palette = "Set2") +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Age", y = "Count", title ="Age Group Distribution")
# age & stroke
ggplot(data, aes(age)) +
geom_bar(aes(fill = stroke)) +
scale_fill_brewer(palette = "Set2") +
facet_grid(cols = vars(stroke)) +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Age", y = "Count", title ="Age Group Distribution with Class Label")
# normalize the height
ggplot(data, aes(age)) +
geom_bar(aes(fill = stroke),
position = "fill",
alpha=0.8) +
labs(x = "Age", y = "Scaled Count", title ="Age Distribution with Normalize Height")
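The normalized-height plots show, for each group, the share of patients in each class. The same numbers can be read off as a table with `prop.table()`. This is a small sketch on made-up data; `toy`, `counts` and `props` are illustrative names.

```r
# Toy data with made-up values; in the analysis these would be the age and stroke factors.
toy <- data.frame(
  age    = factor(c("young", "young", "young", "young", "old", "old", "old", "old"),
                  levels = c("young", "old")),
  stroke = factor(c(0, 0, 0, 0, 0, 0, 1, 1))
)

# Counts per age group and class.
counts <- table(toy$age, toy$stroke)

# margin = 1 normalizes each row, so every age group sums to 1.
# These row proportions are the bar segments drawn by position = "fill".
props <- prop.table(counts, margin = 1)
props
```

Reading the proportions as numbers is useful when the stroke class is rare and its bar segment is too thin to compare by eye.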
# Gender Distribution
ggplot(data, aes(gender)) +
geom_bar(aes(fill = gender)) +
scale_fill_brewer(palette = "Set2") +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Gender", y = "Count", title ="Gender Distribution")
# gender & stroke
ggplot(data, aes(gender)) +
geom_bar(aes(fill = stroke)) +
scale_fill_brewer(palette = "Set2") +
facet_grid(cols = vars(stroke)) +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Gender", y = "Count", title ="Gender Distribution with Class Label")
# normalize the height
ggplot(data, aes(gender)) +
geom_bar(aes(fill = stroke),
position = "fill",
alpha=0.8) +
labs(x = "Gender", y = "Scaled Count", title ="Gender Distribution with Normalize Height")
# Hypertension Distribution
ggplot(data, aes(hypertension)) +
geom_bar(aes(fill = hypertension)) +
scale_fill_brewer(palette = "Set2") +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Hypertension", y = "Count", title ="Hypertension Distribution")
# hypertension & stroke
ggplot(data, aes(hypertension)) +
geom_bar(aes(fill = stroke)) +
scale_fill_brewer(palette = "Set2") +
facet_grid(cols = vars(stroke)) +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Hypertension", y = "Count", title ="Hypertension Distribution with Class Label")
# normalize the height
ggplot(data, aes(hypertension)) +
geom_bar(aes(fill = stroke),
position = "fill",
alpha=0.8) +
labs(x = "Hypertension", y = "Scaled Count", title ="Hypertension Distribution with Normalize Height")
# Heart Disease Distribution
ggplot(data, aes(heart_disease)) +
geom_bar(aes(fill = heart_disease)) +
scale_fill_brewer(palette = "Set2") +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Heart Disease", y = "Count", title ="Heart Disease Distribution")
# heart_disease & stroke
ggplot(data, aes(heart_disease)) +
geom_bar(aes(fill = stroke)) +
scale_fill_brewer(palette = "Set2") +
facet_grid(cols = vars(stroke)) +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Heart Disease", y = "Count", title ="Heart Disease Distribution with Class Label")
# normalize the height
ggplot(data, aes(heart_disease)) +
geom_bar(aes(fill = stroke),
position = "fill",
alpha=0.8) +
labs(x = "Heart Disease", y = "Scaled Count", title ="Heart Disease Distribution with Normalize Height")
# Marital Status
ggplot(data, aes(ever_married)) +
geom_bar(aes(fill = ever_married)) +
scale_fill_brewer(palette = "Set2") +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Marital Status", y = "Count", title ="Marital Status")
# ever_married & stroke
ggplot(data, aes(ever_married)) +
geom_bar(aes(fill = stroke)) +
scale_fill_brewer(palette = "Set2") +
facet_grid(cols = vars(stroke)) +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Marital Status", y = "Count", title ="Marital Status with Class Label")
# normalize the height
ggplot(data, aes(ever_married)) +
geom_bar(aes(fill = stroke),
position = "fill",
alpha=0.8) +
labs(x = "Marital Status", y = "Scaled Count", title ="Marital Status with Normalize Height")
# Distribution of Work Type
ggplot(data, aes(work_type)) +
geom_bar(aes(fill = work_type)) +
scale_fill_brewer(palette = "Set2") +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Work Type", y = "Count", title ="Distribution of Work Type")
# work_type & stroke
ggplot(data, aes(work_type)) +
geom_bar(aes(fill = stroke)) +
scale_fill_brewer(palette = "Set2") +
facet_grid(cols = vars(stroke)) +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Work Type", y = "Count", title ="Distribution of Work Type with Class Label")
# normalize the height
ggplot(data, aes(work_type)) +
geom_bar(aes(fill = stroke),
position = "fill",
alpha=0.8) +
labs(x = "Work Type", y = "Scaled Count", title ="Distribution of Work Type with Normalize Height")
# Distribution of Residence Type
ggplot(data, aes(Residence_type)) +
geom_bar(aes(fill = Residence_type)) +
scale_fill_brewer(palette = "Set2") +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Residence Type", y = "Count", title ="Distribution of Residence Type")
# Residence_type & stroke
ggplot(data, aes(Residence_type)) +
geom_bar(aes(fill = stroke)) +
scale_fill_brewer(palette = "Set2") +
facet_grid(cols = vars(stroke)) +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Residence Type", y = "Count", title ="Distribution of Residence Type with Class Label")
# normalize the height
ggplot(data, aes(Residence_type)) +
geom_bar(aes(fill = stroke),
position = "fill",
alpha=0.8) +
labs(x = "Residence Type", y = "Scaled Count", title ="Distribution of Residence Type with Normalize Height")
# Group of Average Glucose Level
ggplot(data, aes(avg_glucose_level)) +
geom_bar(aes(fill = avg_glucose_level)) +
scale_fill_brewer(palette = "Set2") +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Average Glucose Level", y = "Count", title ="Group of Average Glucose Level")
# avg_glucose_level & stroke
ggplot(data, aes(avg_glucose_level)) +
geom_bar(aes(fill = stroke)) +
scale_fill_brewer(palette = "Set2") +
facet_grid(cols = vars(stroke)) +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Average Glucose Level", y = "Count", title ="Distribution of Average Glucose Level Group with Class Label")
# normalize the height
ggplot(data, aes(avg_glucose_level)) +
geom_bar(aes(fill = stroke),
position = "fill",
alpha=0.8) +
labs(x = "Average Glucose Level", y = "Scaled Count", title ="Group of Average Glucose Level with Normalize Height")
# Distribution of BMI Group
ggplot(data, aes(bmi)) +
geom_bar(aes(fill = bmi)) +
scale_fill_brewer(palette = "Set2") +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "BMI Group", y = "Count", title ="Distribution of BMI Group")
# bmi & stroke
ggplot(data, aes(bmi)) +
geom_bar(aes(fill = stroke)) +
scale_fill_brewer(palette = "Set2") +
facet_grid(cols = vars(stroke)) +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "BMI Group", y = "Count", title ="Distribution of BMI Group with Class Label")
# normalize the height
ggplot(data, aes(bmi)) +
geom_bar(aes(fill = stroke),
position = "fill",
alpha=0.8) +
labs(x = "BMI Group", y = "Scaled Count", title ="BMI Group with Normalize Height")
# Distribution of Smoking Status
ggplot(data, aes(smoking_status)) +
geom_bar(aes(fill = smoking_status)) +
scale_fill_brewer(palette = "Set2") +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Smoking Status", y = "Count", title ="Distribution of Smoking Status")
# smoking_status & stroke
ggplot(data, aes(smoking_status)) +
geom_bar(aes(fill = stroke)) +
scale_fill_brewer(palette = "Set2") +
facet_grid(cols = vars(stroke)) +
geom_text(stat='count', aes(label=..count..), vjust=-0.3) +
labs(x = "Smoking Status", y = "Count", title ="Distribution of Smoking Status with Class Label")
# normalize the height
ggplot(data, aes(smoking_status)) +
geom_bar(aes(fill = stroke),
position = "fill",
alpha=0.8) +
labs(x = "Smoking Status", y = "Scaled Count", title ="Distribution of Smoking Status with Normalize Height")
# age, avg_glucose_level
ggplot(data, aes(age)) +
geom_bar(alpha = 0.8, aes(fill = avg_glucose_level)) +
facet_grid(rows = vars(avg_glucose_level)) +
scale_fill_brewer(palette = "Reds") +
geom_text(stat='count', aes(label=..count..)) +
labs(x = "Age Group", y = "Count", title ="Distribution of Average Glucose Levels in Different Age Groups")
# age, hypertension
ggplot(data, aes(age)) +
geom_bar(alpha = 0.8, aes(fill = hypertension)) +
facet_grid(rows = vars(hypertension)) +
scale_fill_manual(values = c("skyblue", "royalblue", "blue", "navy")) +
geom_text(stat='count', aes(label=..count..)) +
labs(x = "Age Group", y = "Count", title ="Hypertension Status in Different Age Groups")
# age, heart_disease
ggplot(data, aes(age)) +
geom_bar(alpha = 0.8, aes(fill = heart_disease)) +
facet_grid(rows = vars(heart_disease)) +
scale_fill_manual(values = c("skyblue", "royalblue", "blue", "navy")) +
geom_text(stat='count', aes(label=..count..)) +
labs(x = "Age Group", y = "Count", title ="Heart Disease Status in Different Age Groups")
# age, bmi
ggplot(data, aes(age)) +
geom_bar(alpha = 0.8, aes(fill = bmi)) +
facet_grid(rows = vars(bmi)) +
scale_fill_brewer(palette = "Reds") +
geom_text(stat='count', aes(label=..count..)) +
labs(x = "Age Group", y = "Count", title ="BMI Group in Different Age Groups")
# avg_glucose_level, bmi, stroke
ggplot(data, aes(bmi, avg_glucose_level)) +
geom_jitter(alpha = 0.6, aes(color = stroke), size =1) +
facet_grid(rows = vars(stroke)) +
labs(x = "BMI Group", y = "Group of Average Glucose Level", title ="Distribution of BMI & Average Glucose Level with Class Label")
# hypertension, avg_glucose_level, bmi
ggplot(data, aes(bmi, avg_glucose_level)) +
geom_jitter(alpha = 0.6, aes(color = hypertension), size =1) +
facet_grid(rows = vars(hypertension)) +
labs(x = "BMI Group", y = "Group of Average Glucose Level", title ="Distribution of BMI & Average Glucose Level with Different Hypertension Status")
# drop uncorrelated attributes: gender, Residence_type
data_drop <- select(data, -gender, -Residence_type)
# Random Over-Sampling
# move class label to the first column of the dataset
data_md <- data_drop[c(9:1)]
# over sampling data
data_os <- ovun.sample(stroke~., data=data_md, method = "over", p = 0.5, seed = 1)
# check the data after over sampling
str(data_os)
## List of 3
## $ Call : language ovun.sample(formula = stroke ~ ., data = data_md, method = "over", p = 0.5, seed = 1)
## $ method: chr "over"
## $ data :'data.frame': 9367 obs. of 9 variables:
## ..$ stroke : Factor w/ 2 levels "0","1": 1 1 1 1 1 1 1 1 1 1 ...
## ..$ smoking_status : Factor w/ 4 levels "formerly smoked",..: 4 2 4 1 4 4 1 2 3 2 ...
## ..$ bmi : Factor w/ 4 levels "underweight",..: 1 4 1 4 2 4 1 3 4 4 ...
## ..$ avg_glucose_level: Factor w/ 4 levels "low","normal",..: 2 2 3 1 4 4 2 4 2 4 ...
## ..$ work_type : Factor w/ 5 levels "children","Govt_job",..: 1 4 4 4 3 4 4 5 4 5 ...
## ..$ ever_married : Factor w/ 2 levels "No","Yes": 1 2 1 2 1 2 2 2 2 2 ...
## ..$ heart_disease : Factor w/ 2 levels "0","1": 1 1 1 1 1 1 1 2 1 1 ...
## ..$ hypertension : Factor w/ 2 levels "0","1": 1 2 1 1 1 1 1 1 1 2 ...
## ..$ age : Factor w/ 4 levels "young","grown",..: 1 3 1 4 1 3 3 4 2 4 ...
## - attr(*, "class")= chr "ovun.sample"
summary(data_os)
##
## Call:
## ovun.sample(formula = stroke ~ ., data = data_md, method = "over",
## p = 0.5, seed = 1)
##
## Summary of data balanced by oversampling
##
## stroke smoking_status bmi avg_glucose_level
## 0:4700 formerly smoked:1981 underweight: 348 low :1159
## 1:4667 never smoked :3649 normal :2012 normal :3695
## smokes :1580 overweight :2931 prediabetes:1602
## Unknown :2157 obese :4076 diabetes :2911
##
## work_type ever_married heart_disease hypertension age
## children : 687 No :2184 0:8259 0:7602 young :1280
## Govt_job :1184 Yes:7183 1:1108 1:1765 grown :1360
## Never_worked : 22 mature:2380
## Private :5563 old :4347
## Self-employed:1911
table(data_os$data$stroke)
##
## 0 1
## 4700 4667
# move class label (stroke) back to the last column
data_os <- data_os$data[c(9:1)]
# Train and Test Split
set.seed(101)
split <- sample.split(data_os$stroke, SplitRatio = 0.7)
train <- subset(data_os, split == TRUE)
test <- subset(data_os, split == FALSE)
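Because the oversampling above happens before the split, duplicated copies of the same stroke patient can end up in both the training and the test set, which tends to make test metrics look better than they would on unseen patients. A common alternative is to split first and oversample only the training portion. The sketch below illustrates that order in base R on made-up data; `oversample_minority` is a hypothetical helper that does plain random duplication, and it is not the ROSE implementation used above.

```r
set.seed(101)

# Toy imbalanced data (values are made up): 95 negatives, 5 positives.
toy <- data.frame(x = rnorm(100),
                  stroke = factor(c(rep(0, 95), rep(1, 5))))

# 1) Stratified split of roughly 70/30 on the ORIGINAL rows, so no row is in both sets.
train_idx <- unlist(lapply(split(seq_len(nrow(toy)), toy$stroke),
                           function(idx) sample(idx, size = round(0.7 * length(idx)))))
train <- toy[train_idx, ]
test  <- toy[-train_idx, ]

# 2) Oversample the minority class in the training set only.
# oversample_minority is a hypothetical helper: random duplication with replacement.
oversample_minority <- function(df, label) {
  counts   <- table(df[[label]])
  minority <- names(counts)[which.min(counts)]
  min_rows <- which(df[[label]] == minority)
  extra    <- sample(min_rows, size = max(counts) - min(counts), replace = TRUE)
  df[c(seq_len(nrow(df)), extra), ]
}

train_os <- oversample_minority(train, "stroke")

table(train_os$stroke)  # balanced training classes
table(test$stroke)      # test set keeps the original imbalance
```

With this order, the test set reflects the real class balance, so accuracy alone becomes less informative and sensitivity, precision and F1 carry more weight.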
# SVM
model_svm <- svm(stroke ~ ., data = train)
# check the model
summary(model_svm)
##
## Call:
## svm(formula = stroke ~ ., data = train)
##
##
## Parameters:
## SVM-Type: C-classification
## SVM-Kernel: radial
## cost: 1
##
## Number of Support Vectors: 3425
##
## ( 1692 1733 )
##
##
## Number of Classes: 2
##
## Levels:
## 0 1
# use the model on test data to predict our label (stroke)
pred_svm <- predict(model_svm, test[1:8])
# check the model performance
confusionMatrix(pred_svm,
factor(test$stroke),
mode = "everything",
positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 975 189
## 1 435 1211
##
## Accuracy : 0.7779
## 95% CI : (0.7621, 0.7932)
## No Information Rate : 0.5018
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.5561
##
## Mcnemar's Test P-Value : < 2.2e-16
##
## Sensitivity : 0.8650
## Specificity : 0.6915
## Pos Pred Value : 0.7357
## Neg Pred Value : 0.8376
## Precision : 0.7357
## Recall : 0.8650
## F1 : 0.7951
## Prevalence : 0.4982
## Detection Rate : 0.4310
## Detection Prevalence : 0.5858
## Balanced Accuracy : 0.7782
##
## 'Positive' Class : 1
##
# Parameter tuning
# sampling method: 10-fold cross validation
# the tuning call takes a long time to run, so it is commented out here
# tune.results <- tune(svm, train.x = stroke ~ .,
# data = train,
# kernel = 'radial',
# ranges = list(cost = c(1,10),
# gamma = c(0.1,1)))
# tune.results
# set cost = 10, gamma = 1
model_svm <- svm(stroke ~ ., data = train,
kernel = 'radial',
cost = 10,
gamma = 1)
# apply the tuned SVM model on test data to predict class label (stroke)
pred_svm <- predict(model_svm,test[1:8])
# KNN
# convert data type from factor to numeric
data_os_num <- data_os %>% mutate(across(where(is.factor),as.numeric))
data_os_num$stroke <- factor(data_os_num$stroke)
str(data_os_num)
## 'data.frame': 9367 obs. of 9 variables:
## $ age : num 1 3 1 4 1 3 3 4 2 4 ...
## $ hypertension : num 1 2 1 1 1 1 1 1 1 2 ...
## $ heart_disease : num 1 1 1 1 1 1 1 2 1 1 ...
## $ ever_married : num 1 2 1 2 1 2 2 2 2 2 ...
## $ work_type : num 1 4 4 4 3 4 4 5 4 5 ...
## $ avg_glucose_level: num 2 2 3 1 4 4 2 4 2 4 ...
## $ bmi : num 1 4 1 4 2 4 1 3 4 4 ...
## $ smoking_status : num 4 2 4 1 4 4 1 2 3 2 ...
## $ stroke : Factor w/ 2 levels "1","2": 1 1 1 1 1 1 1 1 1 1 ...
# standardize the dataset except class label (stroke)
data_std <- scale(data_os_num[1:8])
head(data_std)
## age hypertension heart_disease ever_married work_type
## 1 -1.90428262 -0.4818205 -0.3662545 -1.813441 -2.3945821
## 2 -0.04243665 2.0752403 -0.3662545 0.551379 0.2379499
## 3 -1.90428262 -0.4818205 -0.3662545 -1.813441 0.2379499
## 4 0.88848633 -0.4818205 -0.3662545 0.551379 0.2379499
## 5 -1.90428262 -0.4818205 -0.3662545 -1.813441 -0.6395607
## 6 -0.04243665 -0.4818205 -0.3662545 0.551379 0.2379499
## avg_glucose_level bmi smoking_status
## 1 -0.6404483 -2.4341304 1.4905203
## 2 -0.6404483 0.9685906 -0.3935231
## 3 0.3171063 -2.4341304 1.4905203
## 4 -1.5980030 0.9685906 -1.3355448
## 5 1.2746609 -1.2998901 1.4905203
## 6 1.2746609 0.9685906 1.4905203
# check variance
var(data_std[,8])
## [1] 1
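Here `scale()` is applied to the whole dataset before it is split, so the test rows contribute to the means and standard deviations. A stricter variant learns the constants from the training rows and reuses them for the test rows. The following sketch shows that pattern on made-up numeric data; all object names are illustrative.

```r
set.seed(101)

# Toy numeric features (made-up values) split into training and test rows.
x <- matrix(rnorm(40, mean = 10, sd = 3), ncol = 2,
            dimnames = list(NULL, c("f1", "f2")))
train_x <- x[1:14, ]
test_x  <- x[15:20, ]

# Learn centering and scaling constants from the training rows only.
train_std <- scale(train_x)
centers   <- attr(train_std, "scaled:center")
scales    <- attr(train_std, "scaled:scale")

# Apply the SAME constants to the test rows.
test_std <- scale(test_x, center = centers, scale = scales)

apply(train_std, 2, sd)   # 1 for every training column
colMeans(test_std)        # near 0, but generally not exactly 0
```

The test columns are not expected to have mean 0 and variance 1 after this step; that is the intended behavior, since the test data plays the role of unseen observations.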
# add label column (stroke) back
data_knn <- cbind(data_std, data_os_num[9])
head(data_knn)
## age hypertension heart_disease ever_married work_type
## 1 -1.90428262 -0.4818205 -0.3662545 -1.813441 -2.3945821
## 2 -0.04243665 2.0752403 -0.3662545 0.551379 0.2379499
## 3 -1.90428262 -0.4818205 -0.3662545 -1.813441 0.2379499
## 4 0.88848633 -0.4818205 -0.3662545 0.551379 0.2379499
## 5 -1.90428262 -0.4818205 -0.3662545 -1.813441 -0.6395607
## 6 -0.04243665 -0.4818205 -0.3662545 0.551379 0.2379499
## avg_glucose_level bmi smoking_status stroke
## 1 -0.6404483 -2.4341304 1.4905203 1
## 2 -0.6404483 0.9685906 -0.3935231 1
## 3 0.3171063 -2.4341304 1.4905203 1
## 4 -1.5980030 0.9685906 -1.3355448 1
## 5 1.2746609 -1.2998901 1.4905203 1
## 6 1.2746609 0.9685906 1.4905203 1
# train and test split for KNN model
set.seed(101)
split_knn <- sample.split(data_knn$stroke, SplitRatio = 0.7)
train_knn <- subset(data_knn, split_knn == TRUE)
test_knn <- subset(data_knn, split_knn == FALSE)
# build KNN model
pred_knn <- knn(train_knn[1:8],
test_knn[1:8],
train_knn$stroke,
k = 1)
# check the error rate
er_knn <- mean(test_knn$stroke != pred_knn)
er_knn
## [1] 0.09644128
# Parameter tuning
for (i in 1:10){
set.seed(101)
pred_knn <- knn(train_knn[1:8],
test_knn[1:8],
train_knn$stroke,
k=i)
er_knn[i] <- mean(test_knn$stroke != pred_knn)
}
# elbow method
k <- 1:10
(df <- data.frame(er_knn, k))
## er_knn k
## 1 0.09644128 1
## 2 0.10640569 2
## 3 0.11886121 3
## 4 0.12669039 4
## 5 0.13594306 5
## 6 0.14555160 6
## 7 0.15053381 7
## 8 0.15836299 8
## 9 0.16298932 9
## 10 0.16548043 10
ggplot(df, aes(k, er_knn)) +
geom_point() +
geom_line(lty="dotted",
color="blue")
# set k = 1
pred_knn <- knn(train_knn[1:8],
test_knn[1:8],
train_knn$stroke,
k = 1)
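In the loop above, k is chosen by comparing error rates on the test set, which is the same set used later to report performance, so the reported error for the chosen k is likely optimistic. One way to avoid this is to pick k by cross-validation inside the training data. The sketch below does this with `knn()` from the class package on made-up data; `cv_error` is a hypothetical helper and the choice of 5 folds is an assumption.

```r
library(class)  # provides knn(); already loaded earlier in this analysis

set.seed(101)

# Toy training data (made-up): two numeric features, two overlapping classes.
n <- 100
train_x <- rbind(matrix(rnorm(n, mean = 0), ncol = 2),
                 matrix(rnorm(n, mean = 1.5), ncol = 2))
train_y <- factor(rep(c("0", "1"), each = n / 2))

# Assign every training row to one of 5 folds at random.
folds <- sample(rep(1:5, length.out = nrow(train_x)))

# cv_error is a hypothetical helper: mean misclassification rate over the folds.
cv_error <- function(k) {
  fold_err <- sapply(1:5, function(f) {
    held_out <- folds == f
    pred <- knn(train_x[!held_out, ], train_x[held_out, ],
                train_y[!held_out], k = k)
    mean(pred != train_y[held_out])
  })
  mean(fold_err)
}

k_grid <- 1:10
cv_err <- sapply(k_grid, cv_error)
best_k <- k_grid[which.min(cv_err)]

data.frame(k = k_grid, cv_err)
best_k
```

The test set is then used once, with the selected k, to estimate the final error rate.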
# Logistic Regression
model_log <- glm(formula = stroke ~ .,
family = binomial(logit),
data = train)
summary(model_log)
##
## Call:
## glm(formula = stroke ~ ., family = binomial(logit), data = train)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.36507 -0.64837 -0.00011 0.82444 2.61605
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) -4.96714 0.47134 -10.538 < 2e-16 ***
## agegrown 15.64066 201.13166 0.078 0.938016
## agemature 17.40326 201.13165 0.087 0.931048
## ageold 18.36437 201.13165 0.091 0.927250
## hypertension1 0.85271 0.08434 10.110 < 2e-16 ***
## heart_disease1 0.56515 0.10699 5.282 1.27e-07 ***
## ever_marriedYes -0.13441 0.10377 -1.295 0.195253
## work_typeGovt_job -14.53861 201.13186 -0.072 0.942376
## work_typeNever_worked -14.15892 945.74134 -0.015 0.988055
## work_typePrivate -14.25017 201.13184 -0.071 0.943517
## work_typeSelf-employed -14.39587 201.13185 -0.072 0.942941
## avg_glucose_levelnormal 0.20099 0.10210 1.969 0.048999 *
## avg_glucose_levelprediabetes 0.20382 0.11646 1.750 0.080104 .
## avg_glucose_leveldiabetes 0.62338 0.10666 5.845 5.08e-09 ***
## bminormal 1.49620 0.41912 3.570 0.000357 ***
## bmioverweight 1.55694 0.41878 3.718 0.000201 ***
## bmiobese 1.54556 0.41865 3.692 0.000223 ***
## smoking_statusnever smoked -0.15925 0.08093 -1.968 0.049086 *
## smoking_statussmokes 0.22365 0.09910 2.257 0.024020 *
## smoking_statusUnknown 0.03292 0.09847 0.334 0.738190
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 9089.9 on 6556 degrees of freedom
## Residual deviance: 6319.6 on 6537 degrees of freedom
## AIC: 6359.6
##
## Number of Fisher Scoring iterations: 16
# apply the model on test dataset
pred_log <- predict(model_log,
newdata = test,
type='response')
res_log <- ifelse(pred_log > 0.5, 1, 0)
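Thresholding at 0.5 turns the predicted probabilities into hard labels, and a ROC curve built from hard labels has only a single operating point. The AUC is usually computed from the probabilities themselves (here that would be `pred_log`). The sketch below shows the difference on made-up values using the rank-based (Mann-Whitney) formula; `auc_rank` is a hypothetical helper.

```r
# Rank-based (Mann-Whitney) AUC: the probability that a randomly chosen positive
# case receives a higher score than a randomly chosen negative case (ties count half).
# auc_rank is a hypothetical helper written for this sketch.
auc_rank <- function(labels, scores) {
  r     <- rank(scores)            # average ranks handle tied scores
  n_pos <- sum(labels == 1)
  n_neg <- sum(labels == 0)
  (sum(r[labels == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Toy example (made-up values): true labels and predicted probabilities.
labels <- c(0, 0, 1, 1, 1)
probs  <- c(0.10, 0.40, 0.35, 0.45, 0.80)

auc_probs <- auc_rank(labels, probs)                      # 5/6, about 0.833
auc_hard  <- auc_rank(labels, ifelse(probs > 0.5, 1, 0))  # 2/3, about 0.667
c(probabilities = auc_probs, hard_labels = auc_hard)
```

In this toy case the hard labels discard the ranking information among cases below the threshold, which is why the two values differ.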
# Evaluate the performance of models
## 1. SVM model
## ROC curve
roc.curve(test$stroke, pred_svm) # roc.curve(response, predicted): true labels come first
## Area under the curve (AUC): 0.911
## plot Confusion Matrix and evaluation metrics
confusionMatrix(pred_svm,
factor(test$stroke),
mode = "everything",
positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 1185 27
## 1 225 1373
##
## Accuracy : 0.9103
## 95% CI : (0.8991, 0.9206)
## No Information Rate : 0.5018
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.8207
##
## Mcnemar's Test P-Value : < 2.2e-16
##
## Sensitivity : 0.9807
## Specificity : 0.8404
## Pos Pred Value : 0.8592
## Neg Pred Value : 0.9777
## Precision : 0.8592
## Recall : 0.9807
## F1 : 0.9159
## Prevalence : 0.4982
## Detection Rate : 0.4886
## Detection Prevalence : 0.5687
## Balanced Accuracy : 0.9106
##
## 'Positive' Class : 1
##
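The statistics reported by `confusionMatrix()` can be reproduced by hand from the four cell counts, which makes it easier to see what each metric measures. The sketch below uses the counts from the tuned SVM confusion matrix above, with "1" (stroke) as the positive class; the variable names are illustrative.

```r
# Counts taken from the tuned SVM confusion matrix reported above
# (rows = prediction, columns = reference; positive class is "1").
tn <- 1185; fn <- 27
fp <- 225;  tp <- 1373

accuracy    <- (tp + tn) / (tp + tn + fp + fn)
sensitivity <- tp / (tp + fn)      # recall: share of actual strokes that are caught
specificity <- tn / (tn + fp)      # share of non-strokes correctly cleared
precision   <- tp / (tp + fp)      # share of predicted strokes that are real
f1          <- 2 * precision * sensitivity / (precision + sensitivity)

round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity, precision = precision, f1 = f1), 4)
```

For a screening task such as stroke prediction, sensitivity is often the metric to watch, since a missed stroke (false negative) is usually more costly than a false alarm.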
## 2. KNN model
## ROC curve
roc.curve(test_knn$stroke, pred_knn) # roc.curve(response, predicted): true labels come first
## Area under the curve (AUC): 0.904
## plot Confusion Matrix and evaluation metrics
confusionMatrix(pred_knn,
                factor(test_knn$stroke),
                mode = "everything",
                positive = "2") # after the numeric conversion, level "2" is the stroke class
## Confusion Matrix and Statistics
##
## Reference
## Prediction 1 2
## 1 1166 27
## 2 244 1373
##
## Accuracy : 0.9036
## 95% CI : (0.892, 0.9142)
## No Information Rate : 0.5018
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.8072
##
## Mcnemar's Test P-Value : < 2.2e-16
##
## Sensitivity : 0.9807
## Specificity : 0.8270
## Pos Pred Value : 0.8491
## Neg Pred Value : 0.9774
## Precision : 0.8491
## Recall : 0.9807
## F1 : 0.9102
## Prevalence : 0.4982
## Detection Rate : 0.4886
## Detection Prevalence : 0.5754
## Balanced Accuracy : 0.9038
##
## 'Positive' Class : 2
##
## 3. Logistic Regression model
## ROC curve
roc.curve(test$stroke, res_log) # roc.curve(response, predicted): true labels come first
## Area under the curve (AUC): 0.769
## plot Confusion Matrix and evaluation metrics
confusionMatrix(factor(res_log),
factor(test$stroke),
mode = "everything",
positive="1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 998 237
## 1 412 1163
##
## Accuracy : 0.769
## 95% CI : (0.753, 0.7845)
## No Information Rate : 0.5018
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.5383
##
## Mcnemar's Test P-Value : 8.486e-12
##
## Sensitivity : 0.8307
## Specificity : 0.7078
## Pos Pred Value : 0.7384
## Neg Pred Value : 0.8081
## Precision : 0.7384
## Recall : 0.8307
## F1 : 0.7818
## Prevalence : 0.4982
## Detection Rate : 0.4139
## Detection Prevalence : 0.5605
## Balanced Accuracy : 0.7693
##
## 'Positive' Class : 1
##